{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.sys.path.append(os.path.dirname(os.path.abspath('.')))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(455, 30) (114, 30) (455,) (114,)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "from datasets.dataset import load_breast_cancer\n",
    "data = load_breast_cancer()\n",
    "X, Y = data.data, data.target\n",
    "del data\n",
    "\n",
    "from model_selection.train_test_split import train_test_split\n",
    "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)\n",
    "print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)\n",
    "\n",
    "# 把X，Y拼起来便于操作\n",
    "training_data = np.c_[X_train, Y_train]\n",
    "testing_data = np.c_[X_test, Y_test]\n",
    "\n",
    "# print(training_data.shape,testing_data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型基础\n",
    "RF跟普通树模型的区别很明显也很简单，每棵树在一个随机抽样的子数据集上训练，并且每次分裂时只在一个随机子空间上做test。为了简便，在抽样数据子集时同时随机选取$\\sqrt{m}$个子特征。\n",
    "\n",
    "注意，为了保持训练与预测时数据的一致性，这里没有丢弃未抽到的特征，而是将未抽到的特征列全部置零，相当于做了一个掩盖操作，不过弊端就是内存占用大。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def RandomPatches(data):\n",
    "    '''\n",
    "    随机抽样函数，同时对样本与特征抽样\n",
    "    '''\n",
    "    n_samples, n_features = data.shape\n",
    "    n_features -= 1\n",
    "    sub_data = np.copy(data)\n",
    "\n",
    "    random_f_idx = np.random.choice(\n",
    "        n_features, size=int(np.sqrt(n_features)), replace=False)\n",
    "    mask_f_idx = [i for i in range(\n",
    "        n_features) if i not in random_f_idx]    # 未抽到的特征idx\n",
    "\n",
    "    random_data_idx = np.random.choice(n_samples, size=n_samples, replace=True)\n",
    "    sub_data = data[random_data_idx]\n",
    "    sub_data[:, mask_f_idx] = 0    # 未抽到的特征列全部置零\n",
    "    return sub_data\n",
    "\n",
    "# RandomPatches(training_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后就可以实现一个简单的串行版本RF模型了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.sys.path.append(os.path.dirname(os.path.abspath('.')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tree.DecisionTreeClassifier import DecisionTreeClassifier\n",
    "\n",
    "\n",
    "def RF_Clf(data, n_estimators=5):\n",
    "    trees = []\n",
    "\n",
    "    for _ in range(n_estimators):\n",
    "        tree = DecisionTreeClassifier()\n",
    "        sub_data = RandomPatches(data)\n",
    "        tree.fit(sub_data[:, :-1], sub_data[:, -1])\n",
    "        trees.append(tree)\n",
    "\n",
    "    return trees\n",
    "\n",
    "\n",
    "trees = RF_Clf(training_data)\n",
    "# print(trees)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "串行预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tree_0 acc:0.8508771929824561\n",
      "tree_1 acc:0.9385964912280702\n",
      "tree_2 acc:0.9122807017543859\n",
      "tree_3 acc:0.956140350877193\n",
      "tree_4 acc:0.9122807017543859\n",
      "rf acc:0.9385964912280702\n"
     ]
    }
   ],
   "source": [
    "from scipy import stats    # numpy未提供mode方法，借助scipy\n",
    "\n",
    "\n",
    "def predict(X_test, trees):\n",
    "    raw_pred = np.array([tree.predict(X_test) for tree in trees]).T\n",
    "    return raw_pred    # 返回原始结果\n",
    "#     return np.array([stats.mode(y_pred)[0][0] for y_pred in raw_pred])    # 返回投票结果\n",
    "\n",
    "\n",
    "Y_pred = predict(testing_data[:, :-1], trees)\n",
    "\n",
    "# 输出每一棵树的单独预测准确率\n",
    "for i in range(len(trees)):\n",
    "    cur_pred = Y_pred[:, i]\n",
    "    print('tree_{} acc:{}'.format(i, np.sum(cur_pred == Y_test)/len(Y_test)))\n",
    "\n",
    "# 输出RF投票后的准确率\n",
    "vote_pred = np.array([stats.mode(y_pred)[0][0] for y_pred in Y_pred])\n",
    "print('rf acc:{}'.format(np.sum(vote_pred == Y_test)/len(Y_test)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过上述结果可以看到(可多次运行)，RF的准确率不会低于任意一颗单CART树。然后对比sklearn中的RF准确率："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sklearn acc:0.9473684210526315\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "rf_clf = RandomForestClassifier(n_estimators=5)\n",
    "rf_clf.fit(X_train, Y_train)\n",
    "Y_pred = rf_clf.predict(X_test)\n",
    "print('sklearn acc:{}'.format(np.sum(Y_pred == Y_test)/len(Y_test)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于没做过Python开发，对Python的并发操作并不是很了解，这里还未实现RF的并行训练，待后期补充。"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
